import torch
import numpy as np
import torch.nn.functional as F
import torch.nn as nn
from tqdm.auto import tqdm

from diffusers import UNet2DModel
from torch import optim
from chip.utils.fourier import fft_2D, ifft_2D
from chip.gridrec.gridRecAUS import gridRec, gridRecAUS
from chip.utils.sinogram import compute_sinogram, Sinogram
from chip.utils.utils import create_circle_filter, find_match_indices


class TomographicStableDiffusion(nn.Module):
    """Stable Diffusion sampler with optional sinogram-consistency guidance.

    Holds a frozen Stable Diffusion stack (UNet, VAE, tokenizer, text encoder)
    and one trainable tensor, ``decoded_x_0_pred``. In
    ``guided_diffusion_pipeline`` every predicted clean latent ``x_0`` is copied
    into that tensor and refined with AdamW so that sinograms of its decoded
    grayscale image match measured high- and/or low-resolution sinograms.

    NOTE(review): despite its name, ``decoded_x_0_pred`` is filled with latents
    (``sinogram_guidance`` copies ``x_0_pred`` into it and passes it through
    ``decode``), so ``image_shape`` presumably has to be the latent shape —
    confirm with callers.
    """

    def __init__(
            self, image_shape, unet: UNet2DModel,
            vae, tokenizer, text_encoder, use_sigmoid=False):
        """Freeze the diffusion components and allocate the optimised tensor.

        Args:
            image_shape: shape of the trainable ``decoded_x_0_pred`` parameter
                (initialised to zeros).
            unet: noise-prediction UNet; frozen and set to eval mode.
            vae: autoencoder used by ``decode``; frozen and set to eval mode.
            tokenizer: tokenizer matching ``text_encoder``.
            text_encoder: prompt encoder; frozen and set to eval mode.
            use_sigmoid: whether ``get_img`` squashes the parameter through a
                scaled sigmoid.
        """
        super().__init__()
        self.unet = unet.requires_grad_(False).eval()
        self.vae = vae.requires_grad_(False).eval()
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder.requires_grad_(False).eval()
        self.decoded_x_0_pred = nn.Parameter(torch.zeros(image_shape))
        # get_img reads this flag, so it must always be defined.
        self.use_sigmoid = use_sigmoid

    def get_img(self):
        """Return the current estimate, optionally squashed into (0, 1).

        With ``use_sigmoid`` the parameter is mapped through
        ``sigmoid(5 * (x - 0.5))``; otherwise it is returned as-is. The result
        is also stored on ``self.computed_image``.
        """
        image = self.decoded_x_0_pred
        # getattr fallback keeps instances created without __init__ working.
        if getattr(self, "use_sigmoid", False):
            image = torch.sigmoid(5 * (image - 0.5))

        self.computed_image = image
        return image

    def step(self, text_embeddings, t, guidance_scale, x_t, noise_scheduler):
        """Run one unguided scheduler step from ``x_t`` at timestep ``t``.

        Returns:
            ``(new_x_t, noise_pred)``: the scheduler's ``prev_sample`` and the
            classifier-free-guided noise prediction.
        """
        noise_scheduler._init_step_index(t)
        _, noise_pred = self.predict_x_0(text_embeddings, guidance_scale, noise_scheduler, t, x_t)
        new_x_t = noise_scheduler.step(noise_pred, t, x_t).prev_sample
        return new_x_t, noise_pred

    def predict_x_0(self, text_embeddings, guidance_scale, noise_scheduler, t, latents):
        """Predict the clean latent ``x_0`` from noisy ``latents`` at timestep ``t``.

        The UNet is evaluated once on a doubled batch; its output is split into
        unconditional / text-conditioned halves (the order produced by
        ``get_text_embedding``) and combined with ``guidance_scale``.
        ``x_0 = x_t - sigma_t * eps`` with ``sigma_t`` read from
        ``noise_scheduler.sigmas`` — this presumably assumes a sigma-based
        (Euler-style) scheduler with epsilon prediction; TODO confirm. The
        scheduler's step index must already be initialised for ``t``.

        Returns:
            ``(x_0_pred, noise_pred)``.
        """
        # Duplicate the batch so one UNet call yields both guidance branches.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

        with torch.no_grad():
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

        # Classifier-free guidance.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        sigma = noise_scheduler.sigmas[noise_scheduler.step_index]
        x_0_pred = latents - sigma * noise_pred
        return x_0_pred, noise_pred

    def get_text_embedding(self, prompt):
        """Encode ``prompt`` for classifier-free guidance.

        Args:
            prompt: a string or a list of strings.

        Returns:
            The unconditional (empty-prompt) embeddings concatenated before the
            prompt embeddings along dim 0.
        """
        torch_device = self.decoded_x_0_pred.device
        text_input = self.tokenizer(
            prompt, padding="max_length", max_length=self.tokenizer.model_max_length,
            truncation=True, return_tensors="pt",
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(torch_device))[0]

        # One empty prompt per *tokenized* row: len(prompt) would count
        # characters when prompt is a plain string.
        batch_size = text_input.input_ids.shape[0]
        max_length = text_input.input_ids.shape[-1]
        uncond_input = self.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(torch_device))[0]

        return torch.cat([uncond_embeddings, text_embeddings])

    def diffusion_pipeline(
            self, prompt: str, x_t_start, t_start, guidance_scale,
            noise_scheduler, num_steps=50,
            verbose=False
    ):
        """Unguided sampling from ``x_t_start``, followed by ``decode``.

        Only timesteps with ``t <= t_start`` are executed, so sampling can
        start from a partially noised latent.

        NOTE(review): the loop visits ``timesteps[:-1]``, i.e. the scheduler's
        final timestep is never taken — confirm this is intentional.
        """
        text_embeddings = self.get_text_embedding(prompt)

        with torch.no_grad():
            x_t = x_t_start.clone()
            noise_scheduler.set_timesteps(num_steps)

            timesteps = noise_scheduler.timesteps
            for i in tqdm(range(1, len(timesteps)), disable=not verbose):
                t = timesteps[i - 1]
                if t <= t_start:
                    x_t, _ = self.step(text_embeddings, t, guidance_scale, x_t, noise_scheduler)

            return self.decode(x_t)

    def decode(self, x_t):
        """Decode latents with the VAE and collapse RGB to one grayscale channel.

        Latents are divided by 0.18215 (the Stable Diffusion v1 latent scaling
        factor) before decoding; the first three output channels are combined
        as ``0.2989 R + 0.5870 G + 0.1140 B``. No ``no_grad`` here: gradients
        flow through this method during ``sinogram_guidance``.

        Returns:
            A tensor shaped like ``image[:, 0]`` of the decoded image
            (presumably ``(batch, H, W)``).
        """
        image = self.vae.decode(x_t / 0.18215).sample
        # Python-float weights broadcast on any device/dtype; no per-call CPU tensor.
        weights = (0.2989, 0.5870, 0.1140)
        return sum(image[:, c] * w for c, w in enumerate(weights))

    def guided_diffusion_pipeline(
            self, prompt, guidance_scale,
            x_t_start, t_start, noise_scheduler, num_steps=50,
            hr_sinogram: Sinogram = None,
            lr_sinogram: Sinogram = None,
            lr_forward_function=None,
            batch_size=10, verbose=False, sgd_steps=50, lr=0.1,
    ):
        """Sampling with sinogram-consistency guidance on each predicted ``x_0``.

        At every executed timestep (``t <= t_start``; as in
        ``diffusion_pipeline`` the final timestep is skipped — NOTE(review)
        confirm) the guided ``x_0`` prediction is refined by
        ``sinogram_guidance`` when at least one sinogram is given. The noise
        implied by the refined ``x_0`` is then recomputed and fed to the
        scheduler step.

        Args:
            lr_forward_function: maps one decoded image to its low-resolution
                counterpart; it is batched with ``torch.vmap``. Defaults to
                the identity.
            batch_size, sgd_steps, lr: forwarded to ``sinogram_guidance``.

        Returns:
            The decoded grayscale image of the final latent.
        """
        device = self.decoded_x_0_pred.device
        text_embeddings = self.get_text_embedding(prompt)
        hr_sinogram = hr_sinogram.to(device) if hr_sinogram is not None else None
        lr_sinogram = lr_sinogram.to(device) if lr_sinogram is not None else None
        if lr_forward_function is None:
            lr_forward_function = lambda x: x  # noqa: E731 - identity default
        # The forward function is written per-image; vmap batches it.
        lr_forward_function = torch.vmap(lr_forward_function)
        use_guidance = hr_sinogram is not None or lr_sinogram is not None

        x_t = x_t_start.clone()
        noise_scheduler.set_timesteps(num_steps)

        timesteps = noise_scheduler.timesteps
        for i in tqdm(range(1, len(timesteps)), disable=not verbose):
            t = timesteps[i - 1]
            if t > t_start:
                continue

            with torch.no_grad():
                noise_scheduler._init_step_index(t)
                x_0_pred, _ = self.predict_x_0(
                    text_embeddings, guidance_scale,
                    noise_scheduler, t, x_t
                )

            if use_guidance:
                x_0_pred = self.sinogram_guidance(
                    x_0_pred, hr_sinogram, lr_sinogram, lr_forward_function,
                    sgd_steps=sgd_steps, batch_size=batch_size, lr=lr,
                )

            with torch.no_grad():
                # Re-derive the noise implied by the (possibly refined) x_0,
                # eps = (x_t - x_0) / sigma_t, so the step moves towards it.
                sigma = noise_scheduler.sigmas[noise_scheduler.step_index]
                modified_noise_pred = (x_t - x_0_pred) / sigma
                x_t = noise_scheduler.step(modified_noise_pred, t, x_t).prev_sample

        with torch.no_grad():
            return self.decode(x_t)

    @staticmethod
    def _as_schedule(lr, sgd_steps):
        """Normalise scalar-or-sequence ``lr`` / ``sgd_steps`` into equal-length lists.

        A scalar is repeated to match the other argument's length; two scalars
        give a single stage.

        Raises:
            ValueError: if both are sequences of different lengths.
        """
        lrs = list(lr) if isinstance(lr, (list, tuple)) else None
        steps = list(sgd_steps) if isinstance(sgd_steps, (list, tuple)) else None
        if lrs is None:
            lrs = [lr] * (len(steps) if steps is not None else 1)
        if steps is None:
            steps = [sgd_steps] * len(lrs)
        if len(lrs) != len(steps):
            raise ValueError(
                f"lr and sgd_steps must have the same length, got {len(lrs)} and {len(steps)}"
            )
        return lrs, steps

    @staticmethod
    def _sinogram_loss(images, target: Sinogram, batch_size, bs):
        """MSE between sinograms of ``images`` and ``target`` on up to ``batch_size`` random angles."""
        idx = torch.randperm(len(target.angles))[:batch_size].tolist()
        predicted = compute_sinogram(images, target.angles[idx])
        measured = target[idx]
        # Every image in the batch is compared against the same measurement.
        return F.mse_loss(
            predicted.reshape(bs, *measured.shape),
            measured.expand(bs, *measured.shape),
        )

    def sinogram_guidance(
            self, x_0_pred,
            hr_sinogram: Sinogram,
            lr_sinogram: Sinogram,
            lr_forward_function=lambda x: x,
            sgd_steps=50,
            batch_size=10,
            lr=0.1, verbose=False
    ):
        """Refine a latent ``x_0`` so that its sinograms match the measurements.

        ``x_0_pred`` is copied into ``self.decoded_x_0_pred`` and optimised
        directly in latent space with AdamW (no VAE encode round trip). Each
        iteration does the following:
          - decodes the latent once;
          - draws up to ``batch_size`` random angles, independently for the
            high- and low-resolution data;
          - computes sinograms at those angles, applying
            ``lr_forward_function`` to the decoded image for the LR term;
          - minimises the summed MSE to the measured sinograms.

        Args:
            lr, sgd_steps: scalars, or equal-length sequences describing a
                multi-stage schedule (one fresh AdamW optimiser per stage). A
                scalar is reused for every stage of the other's sequence.
            lr_forward_function: applied to the decoded image batch before
                computing the LR sinogram.
            verbose: show a progress bar with the current loss.

        Returns:
            A detached copy of the optimised latent. If neither sinogram is
            given, a copy of ``x_0_pred`` is returned unchanged.

        Raises:
            ValueError: if ``lr`` and ``sgd_steps`` are sequences of different
                lengths.
        """
        lrs, steps = self._as_schedule(lr, sgd_steps)
        device = x_0_pred.device

        with torch.no_grad():
            # copy_ (unlike zeroing with ``*= 0``) also discards NaN/inf left
            # behind by a previously diverged optimisation.
            self.decoded_x_0_pred.copy_(x_0_pred)

        hr_sinogram = hr_sinogram.to(device) if hr_sinogram is not None else None
        lr_sinogram = lr_sinogram.to(device) if lr_sinogram is not None else None
        if hr_sinogram is None and lr_sinogram is None:
            # Nothing to fit; backward() on a constant loss would raise.
            return self.decoded_x_0_pred.detach().clone()

        bs = len(self.decoded_x_0_pred)
        for lr_instance, steps_instance in zip(lrs, steps):
            # Only the latent needs an optimiser: every sub-module is frozen.
            optimizer = optim.AdamW([self.decoded_x_0_pred], lr=lr_instance)
            iterator = tqdm(range(steps_instance), disable=not verbose)
            for _ in iterator:
                optimizer.zero_grad()
                # The VAE decode is deterministic; share it between both terms.
                decoded = self.decode(self.decoded_x_0_pred)

                losses = []
                if hr_sinogram is not None:
                    losses.append(self._sinogram_loss(decoded, hr_sinogram, batch_size, bs))
                if lr_sinogram is not None:
                    losses.append(self._sinogram_loss(
                        lr_forward_function(decoded), lr_sinogram, batch_size, bs
                    ))
                loss = sum(losses)

                if verbose:
                    # .item() forces a device sync; only pay for it when shown.
                    iterator.set_postfix({"loss": f"{loss.item():.3f}"})
                loss.backward()
                optimizer.step()

        return self.decoded_x_0_pred.detach().clone()

